# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from benchmark_examples.autoattack.applications.base import (
    ApplicationBase,
    ClassficationType,
    InputMode,
    ModelType,
)
from benchmark_examples.autoattack.attacks.base import AttackBase, AttackType
from benchmark_examples.autoattack.utils.resources import ResourcesPack
from secretflow.ml.nn.callbacks.attack import AttackCallback
from secretflow.ml.nn.sl.attacks.exploitattack_torch import ExploitAttack


class BinaryClassifier(nn.Module):
    """Small two-layer MLP whose outputs are squashed by a sigmoid.

    Layout: ``Linear(input_dim, 64) -> ReLU -> Linear(64, output_dim) -> Sigmoid``.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output units, each mapped into [0, 1].
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Submodule names are kept as-is so state_dict keys stay stable.
        self.layer1 = nn.Linear(input_dim, 64)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(64, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(layer2(relu(layer1(x))))``."""
        hidden = self.relu(self.layer1(x))
        logits = self.layer2(hidden)
        return self.sigmoid(logits)


def exploit_auxiliary_surrogate_model(input_dim, output_dim=1):
    """Return a zero-argument factory that builds the surrogate classifier.

    Args:
        input_dim: feature width fed into the surrogate model.
        output_dim: number of sigmoid outputs; defaults to 1.

    Returns:
        A callable that constructs a fresh
        ``BinaryClassifier(input_dim, output_dim)`` each time it is invoked.
    """

    def create_model():
        return BinaryClassifier(input_dim, output_dim)

    return create_model


class ExploitAttackCase(AttackBase):
    """Benchmark case that plugs secretflow's ``ExploitAttack`` into autoattack.

    Exploit attack needs:
    - ``app.get_train_lable_neg_pos_counts()`` to return the number of
      negative and positive labels in the training set. These counts are
      turned into the label prior ``y_pri`` and its entropy ``H_y``.
    """

    def __str__(self):
        return "exploit"

    @staticmethod
    def _label_prior_and_entropy(neg, pos):
        """Compute the binary label prior and its entropy from label counts.

        Args:
            neg: count of negative training labels (int, numpy scalar or
                0-dim tensor).
            pos: count of positive training labels (same accepted types).

        Returns:
            ``(y_pri, H_y)``: ``y_pri`` is a float32 tensor of shape ``(1, 2)``
            holding ``[1 - p_pos, p_pos]``, and ``H_y`` is a 0-dim float32
            tensor with the entropy (natural log) of that distribution.

        Raises:
            ValueError: if either count is negative or both are zero.
        """
        # Work in float64 so large integer counts divide exactly enough,
        # then hand float32 tensors to the attack as before.
        neg_t = torch.as_tensor(neg, dtype=torch.float64)
        pos_t = torch.as_tensor(pos, dtype=torch.float64)
        if neg_t.item() < 0 or pos_t.item() < 0:
            raise ValueError(
                f"label counts must be non-negative, got neg={neg}, pos={pos}"
            )
        total_t = neg_t + pos_t
        if total_t.item() <= 0:
            raise ValueError(
                f"label counts must not both be zero, got neg={neg}, pos={pos}"
            )

        p_pos = pos_t / total_t
        p_neg = 1 - p_pos
        y_pri = torch.stack([p_neg, p_pos]).unsqueeze(0).to(torch.float32)
        # xlogy(p, p) is 0 when p == 0, so a single-class training set gives
        # H_y == 0 instead of NaN from 0 * log(0).
        H_y = -(torch.xlogy(p_pos, p_pos) + torch.xlogy(p_neg, p_neg))
        return y_pri, H_y.to(torch.float32)

    def build_attack_callback(self, app: ApplicationBase) -> AttackCallback:
        """Build the ``ExploitAttack`` callback for ``app``.

        The attack runs on ``app.device_f`` with the app's training batch
        size, one epoch, a ``BinaryClassifier`` surrogate over
        ``app.hidden_size`` features, and the label prior/entropy derived
        from the training label counts.
        """
        neg, pos = app.get_train_lable_neg_pos_counts()
        y_pri, H_y = self._label_prior_and_entropy(neg, pos)
        return ExploitAttack(
            attack_party=app.device_f,
            batch_size=app.train_batch_size,
            epochs=1,
            surrogate_model=exploit_auxiliary_surrogate_model(app.hidden_size, 1),
            y_pri=y_pri,
            H_y=H_y,
        )

    def attack_type(self) -> AttackType:
        """This case is a label-inference attack."""
        return AttackType.LABLE_INFERENSE

    def tune_metrics(self):
        """Metrics to optimise while tuning: maximise ``accuracy``."""
        return {'accuracy': 'max'}

    def check_app_valid(self, app: ApplicationBase) -> bool:
        """Only DNN apps with binary classification and single input mode apply."""
        return (
            app.model_type() in [ModelType.DNN]
            and app.classfication_type() in [ClassficationType.BINARY]
            and app.base_input_mode() == InputMode.SINGLE
        )

    def update_resources_consumptions(
        self, cluster_resources_pack: ResourcesPack, app: ApplicationBase
    ) -> ResourcesPack:
        """Scale the ``gpu_mem`` estimate by 1.05.

        The scaling is applied to the debug resources and to the simulated
        resources of ``app.device_f``'s party.
        """

        def scale_gpu_mem(x):
            # 5% extra GPU memory; presumably headroom for the attack's
            # surrogate model on device_f — confirm.
            return x * 1.05

        return cluster_resources_pack.apply_debug_resources(
            'gpu_mem', scale_gpu_mem
        ).apply_sim_resources(app.device_f.party, 'gpu_mem', scale_gpu_mem)
